### #########################################
### Estimates the model
### Input is the formula selected from varsel
### Output is a model
### #########################################

here::i_am("newR/05_model.R")

library(here)
library(brms)
library(tidyverse)
library(broom)
library(modelsummary)
library(knitr)

source(here::here("newR", "00_helpers.R"))

### Read the tidy modelling data set
dat <- readRDS(here::here("working", "tidy_model_data.rds"))

### load() restores objects under the names they were saved with.
### NOTE(review): `f_min` and `f_1se`, used further down, are presumably
### among the objects restored here -- confirm against the varsel script.
load(here::here("working", "varsel_formulas.rdata"))

### Vote-share columns, one per party
pvars <- c("Con_BE", "Lab_BE", "Lib_BE", "Nat_BE", "Oth2_BE")

### Any share that is zero or negative is replaced by a small positive
### value, so that every share is strictly positive.
### NOTE(review): the choice of 1 / 40000 is not explained in this file.
for (p in pvars) {
    dat[[p]] <- replace(dat[[p]], dat[[p]] <= 0, 1 / 40000)
}

### Rescale so that the shares in each row sum to 1
dat[, pvars] <- dat[, pvars] / rowSums(dat[, pvars])

### Bind the share columns into a single matrix column, used as the
### dependent variable; column names are the names in `pvars`
dat$y <- do.call(cbind, as.list(dat[pvars]))

### Priors shared by both models, built in one go.
### NOTE(review): no prior is given for a Conservative ("Con") parameter;
### presumably that is the reference category -- confirm.
common_priors <- c(
    ## Intercepts, one per modelled party
    prior("normal(0, 0.429)", class = "Intercept", dpar = "muLabBE"),
    prior("normal(-0.95, 0.584)", class = "Intercept", dpar = "muLibBE"),
    prior("normal(-2.457, 0.714)", class = "Intercept", dpar = "muNatBE"),
    prior("normal(-0.95, 0.584)", class = "Intercept", dpar = "muOth2BE"),
    ## Precision parameter
    prior("normal(11.5, 3.25)", class = "phi"),
    ## Labour poll change: this coefficient gets its own priors here,
    ## once per modelled party
    prior("normal(0.241, 0.12)", coef = "LabPollChg_sc", dpar = "muLabBE"),
    prior("normal(0.06, 0.0905)", coef = "LabPollChg_sc", dpar = "muLibBE"),
    prior("normal(0.06, 0.0905)", coef = "LabPollChg_sc", dpar = "muNatBE"),
    prior("normal(0.06, 0.0905)", coef = "LabPollChg_sc", dpar = "muOth2BE")
)

### Now set up the coefficient-level priors for each model.
###
### add_coef_priors() extends a base prior set with one prior for every
### (coefficient, distributional parameter) pair in a model.
###
### Arguments:
###   model_formula  formula passed on to brms::get_prior()
###   base_priors    brmsprior object to extend (returned unchanged if the
###                  model has no further coefficients)
###   data           data used by get_prior() to expand the formula
###   skip           coefficients that already have priors in `base_priors`
###                  and must not receive another one
### Value: a brmsprior object holding `base_priors` followed by the new
### priors, ordered coefficient by coefficient.
add_coef_priors <- function(model_formula, base_priors, data,
                            skip = "LabPollChg_sc") {
    ## Query brms once; keep only named population-level coefficients
    slopes <- get_prior(model_formula,
                        family = dirichlet(),
                        data = data) %>%
        filter(class == "b", coef != "")

    ## setdiff() also removes the duplicates that arise because each
    ## coefficient is listed once per distributional parameter
    coefs <- setdiff(slopes$coef, skip)
    dpars <- unique(slopes$dpar)

    out <- base_priors
    for (k in coefs) {
        for (p in dpars) {
            ## "muLabBE" -> "Lab", "muOth2BE" -> "Oth", etc.
            short_party <- sub("mu(...).?BE", "\\1", p)
            ## A candidacy variable that concerns the party being
            ## modelled gets a wider prior than everything else.
            ## Both tests are scalar, hence && rather than &.
            own_candidacy <- grepl("Cand_BE", k) && grepl(short_party, k)
            if (own_candidacy) {
                prior_string <- "normal(0, 1)"
            } else {
                prior_string <- "normal(0, 0.25)"
            }

            out <- c(out,
                     set_prior(prior_string, class = "b",
                               coef = k,
                               dpar = p))
        }
    }
    out
}

### The CV-minimizing model first ...
priors_min <- add_coef_priors(f_min, common_priors, data = dat)

### ... then the 1SE model
priors_1se <- add_coef_priors(f_1se, common_priors, data = dat)



### Fit a model, or reuse a cached fit.
###
### fit_or_load() returns the brmsfit stored at `path` if that file
### exists; otherwise it fits the Dirichlet model, saves it to `path`
### and returns it. Both models go through this one function so their
### sampler settings cannot drift apart.
###
### Arguments:
###   model_formula  model formula passed to brm()
###   model_priors   brmsprior object passed to brm()
###   path           location of the cached .rds file
### The model is always fitted to the script-level `dat`.
### NOTE: a cached file is used as-is; delete it to force a refit after
### changing the data, formula or priors.
fit_or_load <- function(model_formula, model_priors, path) {
    if (file.exists(path)) {
        return(readRDS(file = path))
    }

    fit <- brm(model_formula,
               family = dirichlet(),
               prior = model_priors,
               data = dat,
               control = list(adapt_delta = 0.95,
                              max_treedepth = 11),
               cores = 5, chains = 5,
               seed = 43145)
    saveRDS(fit, file = path)
    fit
}

### Estimate the model which minimizes CV error
model_file <- here::here("working",
                         "model_min.rds")
mod <- fit_or_load(f_min, priors_min, model_file)

### Estimate the model selected by the 1-SE rule.
### As before, `mod` and `model_file` end up referring to this model.
model_file <- here::here("working",
                         "model_1se.rds")
mod <- fit_or_load(f_1se, priors_1se, model_file)

### ###################################################
### Start work on model evaluation
### ###################################################

### Change this line depending on the output you want.
model_file <- here::here("working",
                         "model_min.rds")

mod <- readRDS(model_file)

### Get candidate matrix; its columns are in the same party order as
### the columns of dat$y
candvars <- paste0(c("Con", "Lab", "Lib", "Nat", "Oth"),
                  "Cand_BE")
candmat <- as.matrix(dat[,candvars])

### (row, column) positions where the candidacy indicator is 0.
### Computed once and reused for the outcome and for every draw.
no_cand <- which(candmat == 0, arr.ind = TRUE)

### Replace adjusted outcome values: y0 keeps the adjusted shares, y has
### zeros wherever the candidacy indicator is 0
y <- y0 <- dat$y
y[no_cand] <- 0

### Open question carried over from the original script: after zeroing,
### are there rows of y that are now far off?
### NOTE(review): the original comment said "very far from zero";
### possibly rows whose sum is far from 1 was meant -- confirm.

### zero_out_absent() takes a draws x observations x categories array,
### sets the cells listed in `absent` (a two-column row/column index
### matrix) to zero in every draw, and rescales each observation so its
### remaining categories sum to 1.
### NOTE: an observation with every category zeroed yields NaN (0 / 0).
zero_out_absent <- function(draws, absent) {
    ## seq_len() so that an array with no draws is skipped cleanly
    for (s in seq_len(dim(draws)[1])) {
        shares <- draws[s, , ]
        shares[absent] <- 0
        draws[s, , ] <- shares / rowSums(shares)
    }
    draws
}

### Posterior fitted values and posterior predictive draws; the *0
### objects keep the unadjusted versions
f0 <- fitted(mod, summary = FALSE)
pp0 <- posterior_predict(mod)

### Zero out
f <- zero_out_absent(f0, no_cand)
pp <- zero_out_absent(pp0, no_cand)

### In-sample performance table: one row per statistic, all scaled by
### 100. The get_*3() functions are not defined in this file
### (presumably they come from 00_helpers.R).
### NOTE(review): data.frame() defaults to check.names = TRUE, so the
### second column is stored as `In.sample`, not `In-sample`. Add
### check.names = FALSE if the literal header is wanted, after checking
### whatever reads in-sample-perf.rds.
out <- data.frame(
    Statistic = c(
        "Seats correctly predicted",
        "Multiclass Brier score",
        "Mean absolute error",
        "Median absolute error",
        "Predictions inside 95% interval"
    ),
    `In-sample` = 100 * c(
        get_pcp3(f, y),
        get_brier3(f, y),
        get_mae3(f, y),
        get_mae3(f, y, type = "median"),
        get_calibration3(pp, y)
    )
)

### Store the table for later use
saveRDS(out, file = here::here("working", "in-sample-perf.rds"))


### Which observation is predicted worst at the posterior mean?
### fbar averages f over its first dimension, leaving one row per
### observation and one column per party.
fbar <- apply(f, c(2, 3), mean)

### Row with the largest mean absolute error across parties
worst <- which.max(rowMeans(abs(fbar - y)))

### Show which observation that is, with its predicted and actual shares
dat[worst, c("Name", "Date")]
fbar[worst,]
y[worst,]

### Predicted versus actual winner: the party with the largest share
### in each row, labelled with the short party names
party.names <- c("Con", "Lab", "Lib", "Nat", "Oth")

pred_winner <- party.names[apply(fbar, 1, which.max)]
actual_winner <- party.names[apply(dat[,pvars], 1, which.max)]

